// Copyright 2025 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Largely generated by Gemini 2.5 Pro and the following prompt:
//
// Please write a Go program which can parse a .go source file and then
// transform its AST in-place with the following transformations:
//
// FROM: errors.Reason(<msg>).Err()
// TO:   errors.New(<msg>)
//
// FROM: errors.Reason(<msg>).Tag(<tag>).Err()
// TO:   <tag>.Apply(errors.New(<msg>))
//
// FROM: errors.Reason(<msg>, <args...>).Tag(<tag>).Err()
// TO:   <tag>.Apply(errors.Fmt(<msg>, <args...>))
//
// FROM: errors.Reason(<msg>, <args...>).Err()
// TO:   errors.Fmt(<msg>, <args...>)
//
// FROM: errors.Annotate(<err>, "<format>", <args...>).Tag(<tag>).Err()
// TO:   <tag>.Apply(errors.Fmt("<format>: %w", <args...>, <err>))
//
// FROM: errors.Annotate(<err>, "<format>", <args...>).Err()
// TO:   errors.Fmt("<format>: %w", <args...>, <err>)
//
// FROM: errors.New(<message>, <tag>)
// TO:   <tag>.Apply(errors.New(<message>))
//
//
// NOTE: If there are multiple Tags, apply them in right-to-left order (so,
// `...Tag(A).Tag(B)...` turns into `B.Apply(A.Apply(...))`).
//
// Use "golang.org/x/tools/go/ast/astutil" as extra utility methods for the
// implementation in addition to "go/parser" and "go/ast" from the standard
// library.
//
// ... output ...
//
// Adjust the code so that it handles multiple tags provided to
// `errors.New(<message>, <tag>...)`, eg. `errors.New(<message>, tagA, tagB)`
// should turn into `tagB.Apply(tagA.Apply(errors.New(<message>)))`

// Program error_refactor applies an AST transformation to Go sources using the
// old `go.chromium.org/luci/common/errors` annotation functionality to make
// them use the simpler `errors.New` and `errors.Fmt` of the same package, and
// directly use any errtags instead of indirectly via the annotation apis.
//
// This is a temporary transitional program and will be discarded after luci-go
// and infra repos have been converted.
package main

import (
	"bufio"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"log"
	"os"
	"strings"

	"golang.org/x/tools/go/ast/astutil"
)

// unwrapMethodCall reports whether expr has the shape X.Method(Args...).
//
// On a match it returns the receiver expression X, the selected name, the
// call's argument list and true. Any call whose function is a selector
// matches, so a package-qualified call such as pkg.Func(...) is also accepted
// (with pkg as the receiver). Otherwise it returns zero values and false.
func unwrapMethodCall(expr ast.Expr) (receiver ast.Expr, methodName string, args []ast.Expr, ok bool) {
	if callNode, isCall := expr.(*ast.CallExpr); isCall {
		if selector, isSelector := callNode.Fun.(*ast.SelectorExpr); isSelector {
			return selector.X, selector.Sel.Name, callNode.Args, true
		}
	}
	return nil, "", nil, false
}

// unwrapPackageFuncCall reports whether expr has the shape pkg.Func(Args...),
// where pkg is a bare identifier spelled exactly pkgName.
//
// Only the identifier's spelling is compared; import paths are not resolved,
// so any identifier named pkgName matches. On a match it returns the selected
// function name, the call's argument list and true; otherwise zero values and
// false.
func unwrapPackageFuncCall(expr ast.Expr, pkgName string) (funcName string, args []ast.Expr, ok bool) {
	if callNode, isCall := expr.(*ast.CallExpr); isCall {
		if selector, isSelector := callNode.Fun.(*ast.SelectorExpr); isSelector {
			if pkgIdent, isIdent := selector.X.(*ast.Ident); isIdent && pkgIdent.Name == pkgName {
				return selector.Sel.Name, callNode.Args, true
			}
		}
	}
	return "", nil, false
}

// postVisitFunc is the post-order traversal function for astutil.Apply.
// This is where the AST transformations happen.
//
// Because it runs post-order, nested error expressions (for example an
// errors.Reason(...).Err() passed as the <err> of errors.Annotate) have
// already been rewritten by the time their parent is visited. Nodes that do
// not match a supported pattern are left untouched. It always returns true so
// that traversal continues.
func postVisitFunc(c *astutil.Cursor) bool {
	callExpr, ok := c.Node().(*ast.CallExpr)
	if !ok {
		return true // Not a call expression, continue.
	}
	if replacement, ok := rewriteNewWithTags(callExpr); ok {
		c.Replace(replacement)
		return true
	}
	if replacement, ok := rewriteErrChain(callExpr); ok {
		c.Replace(replacement)
	}
	return true
}

// newErrorsCall builds the expression errors.<funcName>(<args...>).
func newErrorsCall(funcName string, args []ast.Expr) *ast.CallExpr {
	return &ast.CallExpr{
		Fun:  &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent(funcName)},
		Args: args,
	}
}

// wrapWithTags wraps base in one <tag>.Apply(...) call per element of tags,
// in slice order, so the last tag ends up as the outermost .Apply() call:
// [A, B] turns base into B.Apply(A.Apply(base)).
func wrapWithTags(base ast.Expr, tags []ast.Expr) ast.Expr {
	for _, tag := range tags {
		base = &ast.CallExpr{
			Fun:  &ast.SelectorExpr{X: tag, Sel: ast.NewIdent("Apply")},
			Args: []ast.Expr{base},
		}
	}
	return base
}

// rewriteNewWithTags handles the extended Rule 7:
//
//	errors.New(<msg>, <tagA>, <tagB>...) -> <tagB>.Apply(<tagA>.Apply(errors.New(<msg>)))
//
// It reports false when call is not errors.New with at least one tag. It also
// reports false when the tags are passed as a variadic spread
// (errors.New(msg, tags...)), because that cannot be expressed as a fixed
// chain of .Apply() calls.
func rewriteNewWithTags(call *ast.CallExpr) (ast.Expr, bool) {
	funcName, args, ok := unwrapPackageFuncCall(call, "errors")
	if !ok || funcName != "New" || len(args) < 2 || call.Ellipsis.IsValid() {
		return nil, false
	}
	return wrapWithTags(newErrorsCall("New", []ast.Expr{args[0]}), args[1:]), true
}

// rewriteErrChain handles expressions that end in a zero-argument .Err() call:
//
//	<base>.Tag(<tagA>).Tag(<tagB>)...Err() -> <tagB>.Apply(<tagA>.Apply(<rewritten base>))
//
// The supported <base> forms are listed on rewriteBase. It reports false,
// leaving the node alone, for anything else.
func rewriteErrChain(call *ast.CallExpr) (ast.Expr, bool) {
	receiver, methodName, methodArgs, ok := unwrapMethodCall(call)
	if !ok || methodName != "Err" || len(methodArgs) != 0 {
		return nil, false
	}
	tags, head := peelTags(receiver)
	base, ok := rewriteBase(head)
	if !ok {
		return nil, false
	}
	return wrapWithTags(base, tags), true
}

// peelTags strips trailing .Tag(<tag>) calls off expr. It returns the tag
// arguments in source order (leftmost .Tag() first) and the expression on
// which the first .Tag() was called.
//
// Peeling stops at any .Tag() call with other than exactly one argument, or
// with a variadic spread (.Tag(tags...)). In that case the returned head is
// that .Tag() call itself, so rewriteBase will not match it and the chain is
// left untransformed.
func peelTags(expr ast.Expr) (tags []ast.Expr, head ast.Expr) {
	for {
		receiver, name, args, ok := unwrapMethodCall(expr)
		// ok guarantees expr is an *ast.CallExpr, so the assertion is safe.
		if !ok || name != "Tag" || len(args) != 1 || expr.(*ast.CallExpr).Ellipsis.IsValid() {
			break
		}
		tags = append(tags, args[0])
		expr = receiver
	}
	// The loop walked the chain right-to-left; restore source order.
	for i, j := 0, len(tags)-1; i < j; i, j = i+1, j-1 {
		tags[i], tags[j] = tags[j], tags[i]
	}
	return tags, expr
}

// rewriteBase converts the head of an .Err() chain into errors.New or
// errors.Fmt:
//
//	errors.Reason(<msg>)                       -> errors.New(<msg>)
//	errors.Reason(<msg>, <args...>)            -> errors.Fmt(<msg>, <args...>)
//	errors.Annotate(<err>, "<fmt>", <args...>) -> errors.Fmt("<fmt>: %w", <args...>, <err>)
//
// A variadic spread in errors.Reason is preserved on the errors.Fmt call.
// errors.Annotate is only rewritten when its format is a string literal
// (interpreted or raw) and its arguments are not spread. With a spread, <err>
// could not be appended after the spread.
func rewriteBase(expr ast.Expr) (ast.Expr, bool) {
	funcName, args, ok := unwrapPackageFuncCall(expr, "errors")
	if !ok {
		return nil, false
	}
	// ok guarantees expr is an *ast.CallExpr.
	spread := expr.(*ast.CallExpr).Ellipsis

	switch funcName {
	case "Reason":
		switch len(args) {
		case 0:
			return nil, false // Invalid errors.Reason() call with no arguments.
		case 1:
			return newErrorsCall("New", args), true
		}
		fmtCall := newErrorsCall("Fmt", args)
		// Keep `args...` spread; dropping it would pass the slice as one operand.
		fmtCall.Ellipsis = spread
		return fmtCall, true

	case "Annotate":
		if len(args) < 2 || spread.IsValid() {
			return nil, false
		}
		formatLit, isLit := args[1].(*ast.BasicLit)
		if !isLit || formatLit.Kind != token.STRING {
			return nil, false // Format must be a string literal to be extended.
		}
		value := formatLit.Value // Includes the surrounding quotes or backticks.
		if len(value) < 2 || (value[0] != '"' && value[0] != '`') || value[len(value)-1] != value[0] {
			return nil, false // Malformed string literal.
		}
		// Insert ": %w" just before the closing delimiter.
		newFormat := &ast.BasicLit{
			Kind:     token.STRING,
			Value:    value[:len(value)-1] + ": %w" + value[len(value)-1:],
			ValuePos: formatLit.ValuePos,
		}
		fmtArgs := make([]ast.Expr, 0, len(args))
		fmtArgs = append(fmtArgs, newFormat)
		fmtArgs = append(fmtArgs, args[2:]...) // Original <args...>.
		fmtArgs = append(fmtArgs, args[0])     // Original <err>, consumed by %w.
		return newErrorsCall("Fmt", fmtArgs), true
	}
	return nil, false // Not errors.Reason or errors.Annotate.
}

// rewrite parses the Go source file at filePath and applies the error-API
// transformations implemented by postVisitFunc. It then overwrites the file
// with the gofmt-formatted result and echoes the path to stderr on success.
//
// Any failure is fatal. The new contents are fully rendered in memory before
// the file is touched, so a formatting failure cannot leave the file truncated.
func rewrite(filePath string) {
	fset := token.NewFileSet()
	fileNode, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
	if err != nil {
		log.Fatalf("Failed to parse file %s: %v", filePath, err)
	}

	// Apply the transformations using astutil.Apply.
	astutil.Apply(fileNode, nil, postVisitFunc)

	// Render into memory first. Opening the file with os.Create would truncate
	// it before we know that formatting succeeds.
	var out strings.Builder
	if err := format.Node(&out, fset, fileNode); err != nil {
		log.Fatalf("Failed to format AST: %v", err)
	}
	// os.WriteFile only uses the permission bits when it creates the file; an
	// existing file keeps its mode. It also reports the Close error.
	if err := os.WriteFile(filePath, []byte(out.String()), 0o666); err != nil {
		log.Fatalf("Failed to write %s: %v", filePath, err)
	}

	fmt.Fprintln(os.Stderr, filePath)
}

// main reads newline-separated Go file paths from stdin and rewrites each one
// in place via rewrite. Surrounding whitespace is trimmed and blank lines are
// skipped. The program takes no command-line arguments.
func main() {
	if len(os.Args) != 1 {
		log.Fatal("Usage: go-errors-transformer < <list of file paths>")
	}
	scn := bufio.NewScanner(os.Stdin)
	for scn.Scan() {
		path := strings.TrimSpace(scn.Text())
		if path == "" {
			// E.g. a trailing blank line. rewrite("") would fail to open the
			// file and abort the whole run.
			continue
		}
		rewrite(path)
	}
	if err := scn.Err(); err != nil {
		log.Fatal(err)
	}
}
